
Diffusing Away From GANs and Transformers

Investigating the math and code behind the current hype surrounding diffusion models and exploring their effectiveness, applicability, and drawbacks.

1. Setting the Background

The current boom in Machine Learning has been brought about primarily by Deep Learning. Deep Learning is the practice of representing data in an N-dimensional space (usually not humanly comprehensible) and learning to produce some useful output from this representation, such as classifying an image as a cat or a dog. These representations can be used not only to classify but also to generate images, bounding boxes, segmentation masks, etc.
In this article, we will survey the common generative models, their advantages and drawbacks, and how diffusion models, the new hype, change the game for generative networks. Fair warning: there is going to be some mathematics involved, but I'll make sure to do some hand-holding for those who'd rather not wade through complex equations.
To get started, let's explore the two most famous architectures used for generating images: Autoencoders and Generative Adversarial Networks.

1.1. Autoencoders and their variations

An autoencoder, as the name suggests, is an encoder of information. A standard autoencoder works by compressing the input into a fixed-size embedding. This embedding contains the information required to produce an output from the desired domain distribution.


A standard autoencoder has two networks: the encoder and the decoder. The aforementioned latent representation is generated by the encoder block which is usually a convolutional neural network. Using the information in this hidden representation, the decoder, also usually a convolutional neural network, can produce the desired output such as a denoised image, segmentation mask, etc.
A common problem with the original autoencoder is that the latent representation $z$ is unconstrained, which can produce undesired results and destabilize training. This led to the introduction of variational autoencoders, which are now the most common autoencoder variant found today. Variational autoencoders enforce a constraint on the mean and variance of the hidden representation, which leads to less noisy images and better convergence.
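For the mathematically inclined, this constraint shows up as an extra term in the VAE's training objective (the evidence lower bound, or ELBO, which we will meet again later): a reconstruction term plus a KL-divergence term that pulls the encoder's distribution $q(z|x)$ toward a standard Gaussian:

$$\mathcal{L}_{\text{ELBO}} = \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] - D_{KL}\!\left(q(z|x)\,\|\,\mathcal{N}(0, I)\right)$$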
In spite of these changes, autoencoders are still notorious for producing noisy images. This changed with Generative Adversarial Networks, which hold the current state-of-the-art in high-quality image generation.

1.2. Generative Adversarial Networks

The generative adversarial network (GAN for short) is the second most well-known architecture. GANs are famous for producing high-fidelity outputs on tasks such as image super-resolution, segmentation mask generation, image translation, and inpainting.

A generative adversarial network consists of two neural networks: the generator and the discriminator. The generator (usually a convolutional neural net) takes random noise as input and produces an image from the desired domain. The discriminator is tasked with differentiating whether the generated image is real or fake, i.e. whether the output belongs to the desired distribution or not.
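The two networks are trained against each other in the classic minimax game:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log\left(1 - D(G(z))\right)\right]$$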
These networks usually require large compute resources for training, but their main drawback is mode collapse. This is a situation where the generator figures out how to exploit the discriminator's bias and produces outputs from a narrow distribution that reliably fools the discriminator. When analyzed, these generated images share common recurring patterns.
Many techniques have been employed to resolve these problems, but the training instability and expense usually deter research in this field. To improve on this, some researchers revisited the well-studied fundamentals of diffusion and proposed a training method that reduces the need for large compute resources while, in many cases, producing better results than GANs.

2. What is Diffusion?

Diffusion models are based on the well-researched concept of diffusion in physics.
In this context, diffusion is the process by which an environment attempts to attain homogeneity, altering its potential gradient in response to the introduction of a new element. Diffusion as a notion is all about a system attaining uniformity.
Diffusion of particles in an environment
But are the states of a diffusion process reversible? Can we identify these newly introduced particles in a homogeneous system? This is exactly what we try to do with diffusion models!
Consider that we have an image: we gradually add noise to it in extremely small steps until we reach a stage ($T$) where the image is completely unrecognizable and has become pure random noise.

Once the forward "noise addition" chain $q$ is complete, we use a deep learning model with trainable parameters $\theta$ to try and recover the image from the noise (the denoising phase $p$) by estimating the noising chain at every timestep.
Diffusion Task: Gradually add noise to the image over $T$ steps in the forward process, then try to recover the original image from the noisy image $x_T$ in the backward process by tracing the chain backwards.
Diffusion was first introduced in 2015 [1] but was recently revived and developed further by researchers at Stanford and Google Brain. Diffusion models are typically classified into two types: continuous and discrete. In the forward chain, the former adds Gaussian noise to continuous signals, whilst the latter obfuscates discrete input tokens using a Markov transition matrix. We'll look at the former in this post, understanding and implementing the equations from the DDPM paper [2] in JAX.

2.1. Forward Pass

The diffusion process is fixed to a Markov chain that gradually adds Gaussian noise to the data according to a variance schedule $\beta_1, \beta_2, \ldots, \beta_T$ where $\beta_1 < \beta_2 < \ldots < \beta_T$.
Let us break this sentence down:

2.1.1. Markov Chain

A Markov chain is a chain of events or states that follows the Markov property. The Markov property states that the distribution of a variable at an arbitrary point in the chain is determined only by the state immediately preceding it.

This means that the state $x_1$ depends only on $x_0$. Similarly, the state $x_2$ depends only on $x_1$, but since $x_1$ depends on $x_0$, any arbitrary state in the chain is indirectly dependent on all the states that occur before it.
From the Markov property, the probability of a chain of events occurring from $x_1$ to $x_T$, given the first state, factorizes as:

$$q(x_{1:T}|x_0) = \prod_{t=1}^{T} q(x_t|x_{t-1})$$

In our case, the probability of a state $x_t$ given $x_{t-1}$ is determined directly by the addition of noise, since the amount of noise in the image at a given stage depends only on how much noise was already present.

2.1.2. Addition of Gaussian Noise

As discussed above, we need the distribution $q(x_t|x_{t-1})$ to generate an image at a given timestamp $t$. For this, we sample some noise and incrementally add it to the image.
Noise drawn from a Gaussian distribution depends on only two parameters: the mean and the standard deviation (or variance). By varying these two values, it is possible to generate an infinite number of noise distributions, one of which is added to the image at every step.
This is where the variance schedule $\beta_1, \beta_2, \ldots, \beta_T$ comes into play. For diffusion models, we fix the variance schedule along the chain. Sampling the noisy state at a given step is then defined as:

$$q(x_t|x_{t-1}) = \mathcal{N}\!\left(x_t;\; \sqrt{1-\beta_t}\, x_{t-1},\; \beta_t I\right)$$

The above line says that we generate $x_t$ from a Gaussian distribution ($\mathcal{N}$) whose mean is $\sqrt{1-\beta_t}\, x_{t-1}$ and whose variance is $\beta_t$ for that step. Combining this definition with the previous equation for $q(x_{1:T}|x_0)$, we can now sample the noise for any given step.
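As a quick illustration, here is a minimal sketch of a single forward step in JAX (the function name and the standalone `beta_t` argument are just for illustration; the full schedule-based implementation appears in section 2.3):

import jax.numpy as jnp
import jax.random as random

# One forward step: sample x_t ~ N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I)
def forward_step(key, x_prev, beta_t):
    noise = random.normal(key, x_prev.shape)
    mean = jnp.sqrt(1.0 - beta_t) * x_prev
    return mean + jnp.sqrt(beta_t) * noise  # std is the square root of the variance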


2.1.3. The Reparameterization Trick

For our training task the model, given the timestamp, is responsible for removing the noise that was added to the image at that timestamp. Generating a noisy image for that timestamp would naively require iterating through the entire chain up to that point. This is extremely inefficient: pythonic loops are slow, and for a large timestamp the chain takes too long to iterate over.
To avoid this, we use the reparameterization trick, which gives a closed form for the noisy image at any required timestamp. The trick works because the sum of two Gaussians is also a Gaussian. The reparameterized formulation is given below:
$$\alpha_t := 1-\beta_t$$

$$\bar\alpha_t := \prod_{s=1}^{t}\alpha_s$$

$$q(x_t|x_0) = \mathcal{N}\!\left(x_t;\; \sqrt{\bar\alpha_t}\, x_0,\; (1-\bar\alpha_t) I\right)$$

Compared to the previous equation, we have isolated the variance schedule and pre-computed the cumulative product of the isolated variable $\alpha_t$. Using this equation, we can now directly sample the noisy image at any timestamp from just the original image.
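To see why this works, write the single-step update as $x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1}$, expand $x_{t-1}$ one step further, and merge the two independent Gaussian noise terms:

$$x_t = \sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_t \alpha_{t-1}}\,\bar\epsilon, \qquad \bar\epsilon \sim \mathcal{N}(0, I)$$

since the merged noise term has variance $\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1-\alpha_t\alpha_{t-1}$. Repeating this all the way down to $x_0$ yields the closed form above, with $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$.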

2.2. Backward Pass

The backward pass aims to turn the noisy image into the desired domain distribution, whether it be for denoising, image super-resolution, or just about anything else!

2.2.1. Autoencoders are back?

For this task, we can use any model with a large enough capacity. Papers usually use autoencoder-style architectures such as U-Nets with global attention, which have proven mathematically and experimentally performant for tasks such as generation and segmentation. The only difference between the U-Net used for diffusion and a standard attention-augmented U-Net is that additional timestamp information is integrated into the model. In general, models with increased width reach the desired sample quality faster than models with increased depth [4].

The above diagram compares a standard U-Net to the modified U-Net that integrates the information provided by the timestamp. The timestamp is first embedded into an N-dimensional vector, which is then added at every level of the model so that the model can learn the correlation between the noise and the timestamp and denoise accordingly.
But wait! Do you notice something weird? The model takes in the timestamp and the noisy image as input and outputs... noise?
Yes! Commonly adopted diffusion models output noise, but that doesn't mean you cannot directly output the image. The model's aim is to output the noise distribution it believes is present in the picture, and this is done purely for convenience: predicting the noise reduces the loss to a simple mean squared error, which makes the whole process easier to reason about.

2.2.2. The Training Loop

The standard equation for the backward pass can be given as:
$$p_\theta(x_{t-1}|x_t) = \mathcal{N}\!\left(x_{t-1};\; \mu_\theta(x_t, t),\; \Sigma_\theta(x_t, t)\right)$$

Here, when going from a state $x_t$ to $x_{t-1}$, the model generates the mean $\mu_\theta$ and covariance $\Sigma_\theta$ of the noise to remove. Researchers found that fixing the variance to the value of $\beta_t$ helps the model converge better. Though this is still under active experimentation, we will go ahead and fix the output variance to $\beta_t$.
Now that we have defined what we need to do, let us define the loss function. The loss function used for diffusion models is derived from the ELBO loss commonly used with variational autoencoders. This loss defines a lower-bound objective, and a simplified version of it can be given as:

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\left\|\epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\; t\right)\right\|^2\right], \qquad \epsilon \sim \mathcal{N}(0, I)$$

The gradient-descent step takes three inputs: the timestamp ($t$), the original image ($x_0$), and some randomly generated Gaussian noise that is to be added to the original image ($\epsilon$). In the forward-propagation step, the model generates the noise it thinks was added to the image, and we calculate the mean squared error between the model's output noise and the original noise. This loss value is then used to calculate the gradients and backpropagate through the autoencoder model.

2.2.3. The Inference Loop

After successfully completing the training process, we must define an inference loop that can generate new samples for us when provided with Gaussian noise. The general algorithm for sampling is given as follows:

Let us run through one loop of sampling. We first sample random noise, which we treat as the image at step $x_T$. Then we simply loop backward from $T$ to $1$, sampling the previous image at each step according to the following formula:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$$

This essentially says that we use the mean predicted by the model and a fixed standard deviation of $\sigma_t = \sqrt{\beta_t}$.


2.3. Implementing a Denoising Diffusion Model (Colab Notebook)

Now that we have skimmed over the theory of training diffusion models, let's get to implementing it.


  1. Defining the imports and initializing the run

import jax
import optax
import os
import math
import wandb

import random as r
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import jax.numpy as jnp
import jax.random as random
import flax.linen as nn
from flax.training import train_state
import matplotlib.pyplot as plt

from typing import Callable
from PIL import Image
from tqdm.notebook import tqdm

# Set only 80% of memory to be accessible. This avoids OOM due to pre-allocation.
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.8

# Defining some hyperparameters
NUM_EPOCHS = 10
BATCH_SIZE = 64
NUM_STEPS_PER_EPOCH = 60000 // BATCH_SIZE # MNIST has 60,000 training samples
USER = "" # Enter your W&B username
PROJECT = "" # Enter your project name

# Initializing W&B run
wandb.init(entity=USER, project=PROJECT)


  2. Defining the Forward Pass

The forward pass algorithm can be written as:
  1. Define the total number of timesteps ($T$) for the chain
  2. Generate $\beta$, $\alpha$ and $\bar\alpha$ for every $t \in T$
  3. Generate noise according to $q(x_t|x_0) = \mathcal{N}(x_t;\, \sqrt{\bar\alpha_t}\, x_0,\, (1-\bar\alpha_t) I)$
# Defining a constant value for T
timesteps = 200

# Defining beta for all t's in T steps
beta = jnp.linspace(0.0001, 0.02, timesteps)

# Defining alpha and its derivatives according to the reparameterization trick
alpha = 1 - beta
alpha_bar = jnp.cumprod(alpha, 0)
sqrt_alpha_bar = jnp.sqrt(alpha_bar)
sqrt_one_minus_alpha_bar = jnp.sqrt(1 - alpha_bar)

# Implement the noising logic according to the reparameterization trick
def forward_noising(key, x_0, t):
    noise = random.normal(key, x_0.shape)
    reshaped_sqrt_alpha_bar_t = jnp.reshape(jnp.take(sqrt_alpha_bar, t), (-1, 1, 1, 1))
    reshaped_sqrt_one_minus_alpha_bar_t = jnp.reshape(jnp.take(sqrt_one_minus_alpha_bar, t), (-1, 1, 1, 1))
    noisy_image = reshaped_sqrt_alpha_bar_t * x_0 + reshaped_sqrt_one_minus_alpha_bar_t * noise
    return noisy_image, noise

# Let us visualize the output image at a few timestamps.
# `sample_mnist` is assumed to be a single (32, 32, 1) image taken from the dataset.
fig = plt.figure(figsize=(15, 30))

for index, i in enumerate([10, 50, 100, 185]):
    noisy_im, noise = forward_noising(random.PRNGKey(0), jnp.expand_dims(sample_mnist, 0), jnp.array([i,]))
    plt.subplot(1, 4, index + 1)
    plt.imshow(jnp.squeeze(jnp.squeeze(noisy_im, -1), 0), cmap='gray')

plt.show()

As we can see, the number becomes progressively more difficult to identify as $t$ increases. At $t=185$, the number is almost completely indistinguishable from the added noise.


  3. Defining the Model

We will be using an attention-augmented UNet architecture for our task. As discussed before, the model takes an additional time embedding to capture the correlation between the timestamp and the amount of noise added to the image.
Before we define the model itself, let us define how the time must be embedded into the model. We use the popular sinusoidal projection that is also commonly used for positional encodings in transformers. We project the time constant into a fixed-dimensional space (128 dimensions in our case), which we will integrate into the model later. Let us code this:
class SinusoidalEmbedding(nn.Module):
    dim: int = 32

    @nn.compact
    def __call__(self, inputs):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = jnp.exp(jnp.arange(half_dim) * -emb)
        emb = inputs[:, None] * emb[None, :]
        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], -1)
        return emb


class TimeEmbedding(nn.Module):
    dim: int = 32

    @nn.compact
    def __call__(self, inputs):
        time_dim = self.dim * 4
        se = SinusoidalEmbedding(self.dim)(inputs)
        # Projecting the embedding into a 128-dim space
        x = nn.Dense(time_dim)(se)
        x = nn.gelu(x)
        x = nn.Dense(time_dim)(x)
        return x
The first building block for the UNet is the attention mechanism. The attention that we will be using is the standard dot-product attention with eight heads.
class Attention(nn.Module):
    dim: int
    num_heads: int = 8
    use_bias: bool = False
    kernel_init: Callable = nn.initializers.xavier_uniform()

    @nn.compact
    def __call__(self, inputs):
        batch, h, w, channels = inputs.shape
        # Flatten the spatial dimensions into a sequence of tokens
        inputs = inputs.reshape(batch, h * w, channels)
        batch, n, channels = inputs.shape
        scale = (self.dim // self.num_heads) ** -0.5
        qkv = nn.Dense(
            self.dim * 3, use_bias=self.use_bias, kernel_init=self.kernel_init
        )(inputs)
        qkv = jnp.reshape(
            qkv, (batch, n, 3, self.num_heads, channels // self.num_heads)
        )
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        attention = (q @ jnp.swapaxes(k, -2, -1)) * scale
        attention = nn.softmax(attention, axis=-1)

        x = (attention @ v).swapaxes(1, 2).reshape(batch, n, channels)
        x = nn.Dense(self.dim, kernel_init=nn.initializers.xavier_uniform())(x)
        # Restore the (assumed square) spatial layout
        x = jnp.reshape(x, (batch, int(x.shape[1] ** 0.5), int(x.shape[1] ** 0.5), -1))
        return x
Next, we will be defining the ResNet block. This ResNet block is only slightly different from the original ResNet block because it also incorporates a time embedding.
class Block(nn.Module):
    dim: int = 32
    groups: int = 8

    @nn.compact
    def __call__(self, inputs):
        conv = nn.Conv(self.dim, (3, 3))(inputs)
        norm = nn.GroupNorm(num_groups=self.groups)(conv)
        activation = nn.silu(norm)
        return activation


class ResnetBlock(nn.Module):
    dim: int = 32
    groups: int = 8

    @nn.compact
    def __call__(self, inputs, time_embed=None):
        x = Block(self.dim, self.groups)(inputs)
        if time_embed is not None:
            # Project the time embedding and add it to every spatial location
            time_embed = nn.silu(time_embed)
            time_embed = nn.Dense(self.dim)(time_embed)
            x = jnp.expand_dims(jnp.expand_dims(time_embed, 1), 1) + x
        x = Block(self.dim, self.groups)(x)
        # 1x1 convolution so the residual matches the output channel count
        res_conv = nn.Conv(self.dim, (1, 1), padding="SAME")(inputs)
        return x + res_conv
Finally, we will implement the UNet. The UNet will have four upsampling and four downsampling blocks.
class UNet(nn.Module):
    dim: int = 8
    dim_scale_factor: tuple = (1, 2, 4, 8)
    num_groups: int = 8

    @nn.compact
    def __call__(self, inputs):
        inputs, time = inputs
        channels = inputs.shape[-1]
        x = nn.Conv(self.dim // 3 * 2, (7, 7), padding=((3, 3), (3, 3)))(inputs)
        time_emb = TimeEmbedding(self.dim)(time)
        dims = [self.dim * i for i in self.dim_scale_factor]
        pre_downsampling = []

        # Downsampling phase
        for index, dim in enumerate(dims):
            x = ResnetBlock(dim, self.num_groups)(x, time_emb)
            x = ResnetBlock(dim, self.num_groups)(x, time_emb)
            att = Attention(dim)(x)
            norm = nn.GroupNorm(self.num_groups)(att)
            x = norm + x
            # Saving this output for the residual connection with the upsampling layer
            pre_downsampling.append(x)
            if index != len(dims) - 1:
                x = nn.Conv(dim, (4, 4), (2, 2))(x)

        # Middle block
        x = ResnetBlock(dims[-1], self.num_groups)(x, time_emb)
        att = Attention(dims[-1])(x)
        norm = nn.GroupNorm(self.num_groups)(att)
        x = norm + x
        x = ResnetBlock(dims[-1], self.num_groups)(x, time_emb)

        # Upsampling phase
        for index, dim in enumerate(reversed(dims)):
            # Concatenate the matching downsampling output (skip connection)
            x = jnp.concatenate([pre_downsampling.pop(), x], -1)
            x = ResnetBlock(dim, self.num_groups)(x, time_emb)
            x = ResnetBlock(dim, self.num_groups)(x, time_emb)
            att = Attention(dim)(x)
            norm = nn.GroupNorm(self.num_groups)(att)
            x = norm + x
            if index != len(dims) - 1:
                x = nn.ConvTranspose(dim, (4, 4), (2, 2))(x)

        # Final ResNet block and output convolutional layer
        x = ResnetBlock(dim, self.num_groups)(x, time_emb)
        x = nn.Conv(channels, (1, 1), padding="SAME")(x)
        return x
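One small but necessary step before the training utilities below can call `model.apply`: instantiate the UNet as a global `model` object. The `dim=32` base width here is an assumption; any width compatible with the group-norm settings works:

# Instantiate the model (base width is a choice, assumed here to be 32)
model = UNet(dim=32)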


  4. Training Loop

Here, we define the training functions and loops in JAX.
According to the formula that we studied previously, the gradient-descent step takes the noisy image, the original noise, and the timestamp, and returns the loss.
# Calculate the gradients and loss value for a specific batch of timestamps
@jax.jit
def apply_model(state, noisy_images, noise, timestamp):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        pred_noise = model.apply({'params': params}, [noisy_images, timestamp])
        # Mean squared error between the actual and the predicted noise
        loss = jnp.mean((noise - pred_noise) ** 2)
        return loss

    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)
    loss, grads = grad_fn(state.params)
    return grads, loss


# Helper function for applying the gradients to the model
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)
The training step performs the following functions:
  1. Generate random PRNGKeys for generating the timestamps and noise
  2. Generate the noisy images
  3. Forward propagate on the UNet
  4. Update the model weights in the backward propagation process according to the calculated gradients
  5. Display loss at that particular step and return the current state and loss
# Define the training loop for one epoch
def train_epoch(epoch_num, state, train_ds, batch_size, rng):
    epoch_loss = []
    num_steps_elapsed = epoch_num * NUM_STEPS_PER_EPOCH

    for index, batch_images in enumerate(tqdm(train_ds)):
        rng, tsrng = random.split(rng)
        # Sample a random timestamp for every image in the batch
        timestamps = random.randint(tsrng,
                                    shape=(batch_images.shape[0],),
                                    minval=0, maxval=timesteps)
        noisy_images, noise = forward_noising(rng, batch_images, timestamps)
        grads, loss = apply_model(state, noisy_images, noise, timestamps)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        wandb.log({"train_loss": loss, 'step': num_steps_elapsed + (index + 1)})
        if index % 10 == 0:
            print(f"Loss at step {index}: ", loss)
        # Timestamps are not needed anymore. Saves some memory.
        del timestamps
    train_loss = np.mean(epoch_loss)

    return state, train_loss
We will create two helper functions for loading the dataset and creating the training state for the model.
# Load and preprocess the MNIST data
def get_datasets():
    ds = tfds.load('mnist', as_supervised=True)
    train_ds, test_ds = ds['train'], ds['test']

    def preprocess(x, y):
        # Scale pixels to [-1, 1] and resize to 32x32 (labels are discarded)
        return tf.image.resize(tf.cast(x, tf.float32) / 127.5 - 1, (32, 32))

    train_ds = train_ds.map(preprocess, tf.data.AUTOTUNE)
    test_ds = test_ds.map(preprocess, tf.data.AUTOTUNE)
    train_ds = train_ds.shuffle(5000).batch(BATCH_SIZE)
    test_ds = test_ds.batch(BATCH_SIZE)

    return tfds.as_numpy(train_ds), tfds.as_numpy(test_ds)


# Creating a train state for our Flax UNet
def create_train_state(rng):
    """Creates the initial `TrainState`."""
    params = model.init(rng, [jnp.ones([1, 32, 32, 1]), jnp.ones([1,])])['params']
    tx = optax.adam(1e-4)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
Before we start the training process, we will define the logic for training which goes as follows:
  1. Generate a PRNGKey which will be used to initialize the weights
  2. Create a training state for our model using the helper function defined before
  3. Iterate over NUM_EPOCHS and for each epoch, call the train_epoch() function.
  4. Log the state at the end of the epoch for future reference (This is optional)
log_state = []

def train(train_ds) -> train_state.TrainState:
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng)

    for epoch in range(1, NUM_EPOCHS + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss = train_epoch(epoch, state, train_ds, BATCH_SIZE, input_rng)
        print("Training loss: ", train_loss)
        # Optionally keep the state after every epoch for future reference
        log_state.append(state)
    return state
All that is left to do now is load the dataset and call this train function:
train_ds, test_ds = get_datasets()
# This function will return the final trained state after `NUM_EPOCHS` epochs
trained_state = train(train_ds)
Let us log and monitor the training loss using Weights and Biases!




  5. Inference Loop

Now that we have trained our model, let us implement a helper function that can take randomly initialized noise and convert it into something that belongs to the input distribution and is more recognizable.
# This function defines the logic of getting x_{t-1} given x_t
def backward_denoising(x_t, pred_noise, t):
    alpha_t = jnp.take(alpha, t)
    alpha_t_bar = jnp.take(alpha_bar, t)
    # Mean of p(x_{t-1} | x_t), as derived in section 2.2.3
    eps_coef = (1 - alpha_t) / (1 - alpha_t_bar) ** .5
    mean = 1 / (alpha_t ** 0.5) * (x_t - eps_coef * pred_noise)
    # Fixed variance sigma_t^2 = beta_t
    var = jnp.take(beta, t)
    eps = random.normal(key=random.PRNGKey(r.randint(1, 100)), shape=x_t.shape)
    # No noise is added on the final step (t = 0)
    eps = jnp.where(t > 0, eps, 0.0)
    return mean + (var ** 0.5) * eps
To use this function, we need random noise, the model's noise prediction at step $t$, and the timestamp $t$. Let us see the code for generating these:
# Generating Gaussian noise
x = random.normal(random.PRNGKey(42), (1, 32, 32, 1))
img_list = []

for i in range(0, timesteps):
    # Walk the chain backwards: t goes from T-1 down to 0
    t = jnp.expand_dims(jnp.array(timesteps - i - 1, jnp.int32), 0)
    pred_noise = model.apply({'params': trained_state.params}, [x, t])
    x = backward_denoising(x, pred_noise, t)

    # Log the image after every 25 iterations
    if i % 25 == 0:
        img_list.append(jnp.squeeze(jnp.squeeze(x, 0), -1))

# Generate a GIF from the logged images (un-normalize from [-1, 1] to [0, 255])
imgs = (Image.fromarray(np.uint8(np.clip((np.array(im) + 1) * 127.5, 0, 255))) for im in img_list)
img = next(imgs)  # extract the first image from the iterator
img.save(fp="output.gif", format='GIF', append_images=imgs,
         save_all=True, duration=200, loop=0)

# Log the GIF to W&B
wandb.log({"Reconstruction-GIFs": wandb.Image("output.gif")})


Note: The loop above is unoptimized, and image generation may take around 3-4 minutes depending on the accelerator.
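Since the reverse loop calls the model once per step, one possible speed-up (a sketch, not part of the original notebook) is to jit a single reverse step so XLA compiles it once and reuses it for all `timesteps` iterations; the PRNG key is passed in explicitly so the function stays pure:

@jax.jit
def jitted_denoise_step(params, x_t, t, key):
    pred_noise = model.apply({'params': params}, [x_t, t])
    alpha_t = jnp.take(alpha, t)
    alpha_t_bar = jnp.take(alpha_bar, t)
    eps_coef = (1 - alpha_t) / (1 - alpha_t_bar) ** 0.5
    mean = 1 / (alpha_t ** 0.5) * (x_t - eps_coef * pred_noise)
    var = jnp.take(beta, t)
    eps = random.normal(key, x_t.shape)
    eps = jnp.where(t > 0, eps, 0.0)  # no noise on the final step
    return mean + (var ** 0.5) * eps

Each iteration would then split the key (`rng, step_rng = random.split(rng)`) and call `jitted_denoise_step(trained_state.params, x, t, step_rng)`.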

2.4. Using Diffusion Models with text prompts

The rapid growth of prompt-based image generation models like DALLE-2 [10] and Imagen [11] can be credited to the rise of diffusion models. Prompt-based generation is the process of generating a viable, high-quality image from a text description or class label. Some examples are given below:

These models are particularly impressive since they are capable of understanding visual context and producing high-resolution outputs. For example, consider the picture in the middle: for the model to produce such an image, it must understand what a dog is, what a cat is, and what a mirror is, along with the concept of reflection in a mirror. If you look closely, you'll also see that the model blurs the reflection slightly, giving it a realistic touch. This text-based generation is possible due to the integration of textual embeddings into the model.

Continuing with the same diagram as before, we can add a label or text prompt that is embedded into the model, either by using a pretrained NLP model's outputs or by using a learned embedding. Though this method works well, it struggles to understand and generate sequential text inside images, such as signboards or warnings.
This led to the introduction of classifier-free guidance. In this method, the model samples the output image once with the text and once without.

The scaled difference is taken in the direction of the text-conditioned output
The text-conditioned output (the former) is more faithful to the prompt than the unconditional one, so the authors [7] take both output vectors, compute their difference, and scale this difference by a predefined factor in the direction of the former. Using this method helps the model recover the missing sequentiality and context.
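To make this concrete, here is a minimal sketch of guided sampling, assuming a hypothetical text-conditional variant of the UNet (`cond_model`) that also accepts a text embedding, with a learned "null" embedding standing in for the unconditional case:

# Classifier-free guidance at sampling time (sketch; `cond_model`,
# `text_emb` and `null_emb` are hypothetical)
def guided_noise(params, x_t, t, text_emb, null_emb, guidance_scale=7.5):
    eps_cond = cond_model.apply({'params': params}, [x_t, t, text_emb])
    eps_uncond = cond_model.apply({'params': params}, [x_t, t, null_emb])
    # Start from the unconditional prediction and push it toward the
    # text-conditioned one, scaled by the guidance factor
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

This guided noise estimate is then plugged into the usual reverse-diffusion step in place of $\epsilon_\theta(x_t, t)$.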
Ongoing research attempts to find better methods to embed the textual information within the model to help the model with sequentiality and better image generation.

3. Improvements in Diffusion Models

Diffusion models are far from perfect! With every paper, diffusion models get better and better. Let us discuss two ways in which the original diffusion model used for denoising was improved.

3.1. Approximation

Usually, for these models to work, the value of $T$ must be set to a high number such as 1000 or 2000. This makes inference longer and computationally heavier. Recent works have brought this figure down to just 25-50 steps [3][4], or even just 10 steps in the case of vector-quantized diffusion models [9]! This is done by using reparameterization tricks or by analytically skipping steps during the backward pass.
To understand this reduction, let us revisit the forward-process equations:
$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\, x_0 + \sqrt{1-\bar\alpha_{t-1}}\, z_{t-1}, \qquad z_{t-1}, z_{t-2}, \ldots \sim \mathcal{N}(0, I)$$

$$= \sqrt{\bar\alpha_{t-1}}\, x_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\, z_t + \sigma_t z$$

$$= \sqrt{\bar\alpha_{t-1}}\, x_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\left(\frac{x_t - \sqrt{\bar\alpha_t}\, x_0}{\sqrt{1-\bar\alpha_t}}\right) + \sigma_t z$$

$$q_\sigma(x_{t-1}|x_t,x_0) = \mathcal{N}\!\left(x_{t-1};\; \sqrt{\bar\alpha_{t-1}}\, x_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma^2_t}\left(\frac{x_t - \sqrt{\bar\alpha_t}\, x_0}{\sqrt{1-\bar\alpha_t}}\right),\; \sigma_t^2 I\right)$$

Comparing this with the original posterior $q(x_{t-1}|x_t,x_0) = \mathcal{N}(x_{t-1};\, \tilde\mu(x_t,x_0),\, \tilde\beta_t I)$, we get:

$$\tilde\beta_t = \sigma_t^2 = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$$

Now let $\sigma_t^2 = \eta\,\tilde\beta_t$, where $\eta$ is treated as a hyperparameter. If we set $\eta$ to 1, we get the standard DDPM, but if we set $\eta$ to 0, the sampling process becomes deterministic. This is exactly what the authors did in the Denoising Diffusion Implicit Models (DDIM) paper [3].
During generation, we then sample only at a sub-sequence of steps $\{\tau_1,\ldots,\tau_S\}$ where $S < T$. With this technique of setting $\eta$ to 0, it is possible to get perceptually cleaner, higher-quality images in 100 or fewer steps from a model that was trained with over 1000!
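To see how this looks in practice, here is a minimal sketch of one deterministic DDIM step ($\eta = 0$), reusing the `alpha_bar` schedule from section 2.3; `t` and `t_prev` are two (possibly far apart) indices from the shortened sub-sequence $\{\tau_1,\ldots,\tau_S\}$:

# One deterministic DDIM step (eta = 0): predict x_0, then jump
# straight to the earlier timestep without adding fresh noise
def ddim_step(x_t, pred_noise, t, t_prev):
    a_t = jnp.take(alpha_bar, t)
    a_prev = jnp.take(alpha_bar, t_prev)
    x0_pred = (x_t - jnp.sqrt(1 - a_t) * pred_noise) / jnp.sqrt(a_t)
    return jnp.sqrt(a_prev) * x0_pred + jnp.sqrt(1 - a_prev) * pred_noise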

3.2. Latent Diffusion

Another proposition to optimize the training is to reduce the size of the image on which diffusion is performed. This is called latent diffusion.

With latent diffusion, we can avoid processing large 512×512 px or 1024×1024 px images and can instead shrink them to a friendlier size such as 28×28 px or 32×32 px. Diffusion is then applied to this downsampled representation, whose output is then upsampled back to full resolution. This, when combined with the approximation trick, lets us sample very high-resolution images comparatively quickly.
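Conceptually, the pipeline looks like the sketch below; the encoder/decoder pair and the `reverse_diffusion` loop are hypothetical placeholders, and the UNet itself is unchanged, it simply operates on a much smaller tensor:

# Latent diffusion, end to end (conceptual sketch)
def latent_diffusion_sample(unet_params, decoder, decoder_params, key):
    # 1. Sample noise directly in the small latent space, e.g. 32x32
    z = random.normal(key, (1, 32, 32, 4))
    # 2. Run the reverse-diffusion loop of section 2.2.3 in latent space,
    #    which is far cheaper than running it at 512x512
    z = reverse_diffusion(unet_params, z)
    # 3. Decode the clean latent back to pixel space
    return decoder.apply({'params': decoder_params}, z)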
As research on diffusion models progresses, we will soon get to see more optimizations and tricks to increase image quality and reduce training overheads.

4. Advantages of Diffusion

Even though diffusion is a new domain for most practitioners, it has some clear advantages over previous methods.
  1. Less parameter tuning and stable training: Compared to GANs, diffusion models don't require training-stabilization tricks.
  2. Faster training and inference with optimized diffusion techniques: As discussed above, techniques like latent diffusion combined with approximation lead to much faster sampling as well.
  3. High-fidelity outputs: The perceptual quality of the generated images is superior to that of GANs and autoencoders.

5. References

  1. GLID-3 (Alex Nichol et al., 2022)
  2. Improved Vector Quantized Diffusion Models (Zhicong Tang et al., 2022)

If you liked reading this blog, consider following me on Twitter where I share more such content!